from torch.utils.data import Dataset
from PIL import Image
import os
# from torchvision.prototype.datasets.utils import Dataset

class Mydata(Dataset):
    """Dataset over the files found in one label sub-directory.

    When defining a custom dataset, inherit from the parent class ``Dataset``
    and override the three methods below: ``__init__``, ``__getitem__`` and
    ``__len__``.

    Each sample is an ``(image, label)`` pair in which ``label`` is the name
    of the sub-directory the file was listed from.
    """

    def __init__(self, root_dir: str, label_dir: str) -> None:
        """Record the directory layout and list the files it contains.

        Args:
            root_dir: Directory that contains the label sub-directories.
            label_dir: Name of the sub-directory to read; also used as the
                label of every sample.

        Raises:
            OSError: If ``root_dir/label_dir`` cannot be listed.
        """
        super().__init__()
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.label_dir)
        # os.listdir() yields entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible across runs and platforms.
        self.image_path = sorted(os.listdir(self.path))

    def __getitem__(self, idx: int) -> "tuple[Image.Image, str]":
        """Return the ``(image, label)`` pair stored at position ``idx``.

        Args:
            idx: Position in the sorted file listing.

        Returns:
            The object returned by ``Image.open`` for that file, and the
            label directory name.

        Raises:
            IndexError: If ``idx`` is outside the file listing.
        """
        image_name = self.image_path[idx]
        image_item_path = os.path.join(self.path, image_name)
        image = Image.open(image_item_path)
        label = self.label_dir
        return image, label

    def __len__(self) -> int:
        """Return the number of entries listed in the label directory."""
        return len(self.image_path)

# Directory names are relative paths, so they resolve against the current
# working directory when the datasets below list their files.
train_root_dir = "train"
ants_label_dir, bees_label_dir = "ants", "bees"

# One dataset per label sub-directory; each sample's label is the
# sub-directory name.
ants_dataset = Mydata(root_dir=train_root_dir, label_dir=ants_label_dir)
bees_dataset = Mydata(root_dir=train_root_dir, label_dir=bees_label_dir)

